import pandas as pd
from transformers import pipeline, logging
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, balanced_accuracy_score, precision_recall_fscore_support, classification_report
from datasets import load_dataset, DatasetDict, Dataset, disable_progress_bars
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
import numpy as np
# Turn off the progress bars of the `datasets` library to keep the log compact.
disable_progress_bars()
import warnings
# Silence every UserWarning and FutureWarning raised from here on.
# NOTE(review): this also hides deprecation notices from the libraries used
# below, so upcoming behaviour changes will not be visible in the output.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# Opt in to pandas' future "no silent downcasting" behaviour.
pd.set_option('future.no_silent_downcasting', True)
# Disable pandas' chained-assignment warning.
pd.options.mode.chained_assignment = None 
# `logging` here is the transformers logging module (see the imports above):
# only error-level messages are shown.
logging.set_verbosity_error()

###########################
## Set device from CLI args
###########################
import os
import argparse
def get_device():
    """Return the name of the compute device requested by the user.

    Lookup order: the ``--device`` command-line option, then the ``DEVICE``
    environment variable, then the built-in default ``"cuda"``. Blank values
    are skipped, and the chosen value is stripped and lower-cased so that
    spellings such as ``CUDA`` or ``" cpu "`` are accepted.

    Returns:
        str: The requested device name. The name is not validated here; the
        caller decides which names are supported.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default=None, help="Device to use: cpu, mps, cuda")
    # parse_known_args ignores unrelated options so the script still works
    # normally when other arguments are present on the command line.
    args, _ = parser.parse_known_args()
    # Priority: CLI argument > DEVICE environment variable > default "cuda".
    for candidate in (args.device, os.getenv("DEVICE")):
        if candidate and candidate.strip():
            return candidate.strip().lower()
    return "cuda"
DEVICE = get_device()
print(f"[INFO] Using DEVICE={DEVICE}")
# PyTorch setup
import torch

# Reject unsupported names first; afterwards an accelerator is used only when
# torch reports it as available, with the CPU as the fallback in every case.
if DEVICE not in ("cpu", "mps", "cuda"):
    raise ValueError(f"Unknown device {DEVICE}")
if DEVICE == "mps" and torch.backends.mps.is_available():
    device = torch.device("mps")
elif DEVICE == "cuda" and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"[INFO] Torch device set to {device}")

###########################
## Data, functions, etc.
###########################

# Trainer output directory and the Hugging Face checkpoint used throughout.
training_directory = 'fewshot'
modname = "mlburnham/Political_DEBATE_DeBERTa_large_v1.1"

# Load the evaluation data and make sure every premise is a string.
fr = pd.read_csv('./data/freedom_test.csv')
fr['premise'] = fr['premise'].astype(str)

# Tokenizer and two-label classifier are both loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(modname)
model = AutoModelForSequenceClassification.from_pretrained(
    modname,
    num_labels=2,
    ignore_mismatched_sizes=True,
)

def metrics(df, preds, group_by=None):
    """Score prediction columns against the ``'entailment'`` column of ``df``.

    Args:
        df: DataFrame holding the true labels in ``'entailment'``, the
            prediction columns named in ``preds`` and, when grouping, the
            ``group_by`` column.
        preds: Name of one prediction column, or an iterable of such names.
        group_by: ``'dataset'`` or ``'task'`` to report the metrics per group
            of that column; any other value scores each column over all rows.

    Returns:
        DataFrame with ``MCC``, ``Accuracy`` and weighted ``F1`` columns,
        indexed by ``Column`` (plus the capitalised group column when
        grouping). Empty ``preds`` gives an empty frame with the same layout.
    """
    true_col = 'entailment'
    # A bare string would otherwise be iterated character by character.
    if isinstance(preds, str):
        preds = [preds]
    grouped = group_by in ('dataset', 'task')

    def get_metrics(y_true, y_pred):
        return {
            'MCC': matthews_corrcoef(y_true, y_pred),
            'Accuracy': accuracy_score(y_true, y_pred),
            'F1': f1_score(y_true, y_pred, average='weighted')
        }

    rows = []
    columns = ['MCC', 'Accuracy', 'F1', 'Column']

    if grouped:
        group_label = group_by.capitalize()
        columns.append(group_label)
        index_cols = ['Column', group_label]
        # Split the frame once and reuse the groups for every column.
        groups = list(df.groupby(group_by))
        for col in preds:
            for group_name, group in groups:
                row = get_metrics(group[true_col], group[col])
                row['Column'] = col
                row[group_label] = group_name
                rows.append(row)
    else:
        index_cols = 'Column'
        for col in preds:
            row = get_metrics(df[true_col], df[col])
            row['Column'] = col
            rows.append(row)

    # Naming the columns explicitly keeps set_index working when no rows
    # were produced (an empty frame would otherwise have no 'Column').
    results_df = pd.DataFrame(rows, columns=columns)
    return results_df.set_index(index_cols)

def tokenize_function(docs):
    """Tokenize premise/hypothesis pairs with the module-level ``tokenizer``.

    Args:
        docs: Mapping with ``'premise'`` and ``'hypothesis'`` entries
            (presumably a batch handed over by ``Dataset.map``).

    Returns:
        Whatever the tokenizer returns for the pairs, padded to
        ``max_length`` and truncated.
    """
    premises, hypotheses = docs['premise'], docs['hypothesis']
    return tokenizer(premises, hypotheses, padding='max_length', truncation=True)


def compute_metrics_standard(eval_pred, label_text_alphabetical=list(model.config.id2label.values())):
    """Compute, print and return classification metrics for one evaluation.

    Args:
        eval_pred: Object exposing ``label_ids`` (gold labels) and
            ``predictions`` (scores per class; the argmax over axis 1 is
            taken as the predicted class).
        label_text_alphabetical: Class names used as ``target_names`` in the
            printed report; the name at position ``i`` labels class id ``i``.
            The default is the label names of the module-level ``model``,
            captured once when the function is defined.

    Returns:
        dict: MCC, macro/micro F1, balanced and plain accuracy, and
        macro/micro precision and recall.
    """
    labels = eval_pred.label_ids
    preds_max = np.argmax(eval_pred.predictions, axis=1)

    # metrics
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(labels, preds_max, average='macro')
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(labels, preds_max, average='micro')
    acc_balanced = balanced_accuracy_score(labels, preds_max)
    acc_not_balanced = accuracy_score(labels, preds_max)
    mcc = matthews_corrcoef(labels, preds_max)

    scores = {
        'MCC': mcc,
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'accuracy_balanced': acc_balanced,
        'accuracy': acc_not_balanced,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'precision_micro': precision_micro,
        'recall_micro': recall_micro,
    }
    print("Aggregate metrics: ", scores)

    # Class ids are simply 0..n-1 in the order of the names. For unique names
    # this equals the previous sorted factorize codes, but it also stays
    # correct when names repeat and avoids handing a plain list to
    # pd.factorize (NOTE(review): list input is reported as deprecated in
    # recent pandas releases - confirm against the installed version).
    class_ids = np.arange(len(label_text_alphabetical))
    print("Detailed metrics: ", classification_report(
        labels, preds_max, labels=class_ids,
        target_names=label_text_alphabetical, sample_weight=None,
        digits=2, output_dict=True, zero_division='warn'),
    "\n")

    return scores

#############
## Zero Shot
#############
# Run the pipeline on the torch device resolved at start-up instead of a
# hard-coded GPU index, so CPU/MPS requests and CUDA-less machines work.
pipe = pipeline("zero-shot-classification", model=modname, device=device, batch_size=32)
res = pipe(
    list(fr['premise'].str.lower()),
    ['freedom and rights except voting.'],
    hypothesis_template='This text is about {}',
    multi_label=False,
)
# Round the score of the single candidate label to a hard 0/1 decision.
labels = [round(label['scores'][0], 0) for label in res]
fr['0_shot'] = labels
# Flip 0 <-> 1 (presumably because the 'entailment' column codes a match as 0,
# cf. the id2label mapping in the few-shot section). The result is assigned
# back rather than replaced in place on the selected column: with pandas
# Copy-on-Write an in-place call on fr['0_shot'] only changes a temporary
# Series and leaves fr untouched.
fr['0_shot'] = fr['0_shot'].replace({0: 1, 1: 0})

# One-row summary of the zero-shot run (n = 0 training examples).
zs_fr = pd.DataFrame({'n': 0, 'mcc': matthews_corrcoef(fr['entailment'], fr['0_shot']),
                      'accuracy': accuracy_score(fr['entailment'], fr['0_shot']),
                      'f1': f1_score(fr['entailment'], fr['0_shot'])}, index=[0])

#############
## Few Shot
#############
# Half precision is requested only when the resolved torch device is CUDA;
# on other devices TrainingArguments rejects fp16, so it is switched off.
use_half_precision = device.type == "cuda"
training_args = TrainingArguments(
    output_dir=training_directory,
    logging_dir=f'{training_directory}/logs',
    lr_scheduler_type="linear",
    group_by_length=False,
    learning_rate=9e-6,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    num_train_epochs=5,
    warmup_ratio=0.06,
    weight_decay=0.01,
    fp16=use_half_precision,
    fp16_full_eval=use_half_precision,
    # Keep the Trainer on the CPU when that is the resolved device.
    use_cpu=device.type == "cpu",
    # No intermediate evaluation or checkpoints; predictions are made
    # explicitly after training.
    eval_strategy="no",
    seed=1,
    save_strategy="no",
    dataloader_num_workers=0,
    disable_tqdm=True,
)
# Factory handed to the Trainer so that every run loads the model anew from
# the pretrained checkpoint. This will make results reproducible.
def model_init():
    """Return a freshly loaded two-label classifier from ``modname``."""
    fresh_model = AutoModelForSequenceClassification.from_pretrained(
        modname,
        num_labels=2,
        ignore_mismatched_sizes=True,
    )
    return fresh_model

# Training-set sizes (shots) to evaluate; each size is repeated for the
# sampling seeds 0-19.
shots = [10, 25, 50, 100]

# Parallel result lists: one entry per (shot size, seed) run.
mcc_list, acc_list, f1_list, shots_list = [], [], [], []

for shot in shots:
    for seed in range(20):
        print("DEBATE Large seed {}, {} shots".format(seed, shot))
        # Draw the few-shot training rows; every other row is held out.
        train = fr.sample(shot, random_state=seed)
        held_out_mask = ~fr.index.isin(train.index)
        val = fr[held_out_mask]

        # Wrap both splits as Hugging Face datasets.
        ds = DatasetDict({
            'train': Dataset.from_pandas(train, preserve_index=False),
            'validation': Dataset.from_pandas(val, preserve_index=False),
        })
        # Tokenize, then expose the gold column under the name 'label'.
        dstok = ds.map(tokenize_function, batched=True).rename_columns({'entailment': 'label'})
        # Label mapping (not referenced again in this loop).
        id2label = {0: "entailment", 1: "not_entailment"}

        # A new Trainer per run; model_init reloads the pretrained model.
        trainer = Trainer(
            model_init=model_init,
            tokenizer=tokenizer,
            args=training_args,
            train_dataset=dstok['train'],
            eval_dataset=dstok['validation'],
            compute_metrics=lambda eval_pred: compute_metrics_standard(
                eval_pred,
                label_text_alphabetical=list(model.config.id2label.values()),
            ),
        )

        trainer.train()
        # Predict on the held-out rows and take the highest-scoring class.
        res = trainer.predict(dstok['validation'])
        preds = np.argmax(res.predictions, axis=-1)

        # Record the scores of this run together with its shot size.
        gold = val['entailment']
        mcc_list.append(matthews_corrcoef(gold, preds))
        acc_list.append(accuracy_score(gold, preds))
        f1_list.append(f1_score(gold, preds))
        shots_list.append(shot)
    
######################
## Compile and export
######################
# Few-shot scores, one row per run, stacked below the zero-shot row.
fewshot_scores = pd.DataFrame({
    'n': shots_list,
    'f1': f1_list,
    'mcc': mcc_list,
    'accuracy': acc_list,
})
res_fr = pd.concat([zs_fr, fewshot_scores], axis=0)
# Write the combined table without the index column.
res_fr.to_csv('./data/motn_fewshot_large.csv', index=False)